K-means

Use K-means for color compression

This notebook uses the JPEG photograph of tulips below as its "dataset." It reads the pixel values that encode the image and uses them to fit K-means models.

tulips
In [6]:
import numpy as np
import pandas as pd
%matplotlib inline
import plotly.graph_objects as go
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
In [7]:
# Load the photo as a (height, width, channels) array of pixel values.
img = plt.imread('using_kmeans_for_color_compression_tulips_photo.jpg')
print(img.shape)
# Every later cell reshapes the image to (n_pixels, 3); fail early with a clear
# message on a grayscale or RGBA image instead of deep inside a reshape.
assert img.ndim == 3 and img.shape[-1] == 3, f'expected an RGB image, got shape {img.shape}'
plt.imshow(img)
# Trailing ';' suppresses the axis-limits repr that plt.axis() returns.
plt.axis('off');
(320, 240, 3)
Out[7]:
(-0.5, 239.5, 319.5, -0.5)
In [8]:
# Turn the image into a pixel table: one row per pixel, with the three
# channel values (R, G, B) as columns. Show the first five rows.
img_flat = img.reshape((img.shape[0] * img.shape[1], 3))
img_flat[:5]
Out[8]:
array([[211, 197,  38],
       [199, 181,  21],
       [178, 154,   0],
       [185, 152,   0],
       [184, 145,   0]], dtype=uint8)
In [9]:
np.shape(img_flat)  # (number of pixels, 3 channels)
Out[9]:
(76800, 3)
In [10]:
# Wrap the pixel table in a DataFrame with one named column per channel.
img_flat_df = pd.DataFrame(data=img_flat, columns=list('rgb'))
img_flat_df.head(5)
Out[10]:
r g b
0 211 197 38
1 199 181 21
2 178 154 0
3 185 152 0
4 184 145 0
In [11]:
# Interactive 3-D scatter of the pixel table: each pixel is placed at its
# (R, G, B) coordinates and drawn in its own colour.
pixel_colors = [f'rgb({r},{g},{b})' for r, g, b in zip(img_flat_df.r.values,
                                                         img_flat_df.g.values,
                                                         img_flat_df.b.values)]

trace = go.Scatter3d(
    x=img_flat_df.r,
    y=img_flat_df.g,
    z=img_flat_df.b,
    mode='markers',
    marker={'size': 1, 'color': pixel_colors, 'opacity': 0.5},
)
data = [trace]
layout = go.Layout(margin={'l': 0, 'r': 0, 'b': 0, 't': 0})

fig = go.Figure(data=data, layout=layout)
fig.update_layout(scene={'xaxis_title': 'R', 'yaxis_title': 'G', 'zaxis_title': 'B'})
fig.show()
In [19]:
# Build a single-cluster K-means model and fit it to the pixel table.
# (With one cluster, the lone centroid should sit at the per-channel mean.)
kmeans = KMeans(n_clusters=1, n_init='auto', random_state=19991209)
kmeans.fit(img_flat);
In [20]:
# Paint every pixel with the centroid of the cluster it was assigned to.
# Indexing the centroid array with the per-pixel label array does the lookup
# for all pixels at once (no Python loop over clusters).
img_flat1 = kmeans.cluster_centers_[kmeans.labels_]
# Centroids are float means. For an integer image, round to the nearest value
# before casting back; a plain cast would truncate toward zero and darken.
if np.issubdtype(img_flat.dtype, np.integer):
    img_flat1 = np.rint(img_flat1)
img_flat1 = img_flat1.astype(img_flat.dtype)

img1 = img_flat1.reshape(img.shape)

plt.imshow(img1)
plt.axis('off');

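The result is our tulip image with every pixel replaced by the average color. With a single cluster, K-means places its only centroid at the mean of all pixels, so every pixel is assigned that mean color. For this photo, the average color is brown — all the colors muddled together.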

In [21]:
# Average each channel (R, G, B) over all pixels, for comparison with the
# single K-means centroid found above.
column_means = np.mean(img_flat, axis=0)

print('column means: ', column_means)
column means:  [125.60802083  78.90632813  43.45473958]
In [22]:
# Redraw the pixel cloud (each pixel in its own colour) and overlay the single
# K-means centroid as a large marker.
trace = go.Scatter3d(x=img_flat_df.r,
                     y=img_flat_df.g,
                     z=img_flat_df.b,
                     mode='markers',
                     marker=dict(size=1,
                                 color=['rgb({},{},{})'.format(r, g, b) for
                                        r, g, b in zip(img_flat_df.r.values,
                                                       img_flat_df.g.values,
                                                       img_flat_df.b.values)],
                                 opacity=0.5))

data = [trace]

layout = go.Layout(margin=dict(l=0, r=0, b=0, t=0))

fig = go.Figure(data=data, layout=layout)

# Add centroid to chart. Its marker colour is derived from the centroid's own
# coordinates (previously a hard-coded string that no longer matched the
# fitted centroid). Components are formatted as integers: the colour is 8-bit
# anyway, and this avoids exponent notation that is not a valid rgb() string.
centroid = kmeans.cluster_centers_[0].tolist()
centroid_color = 'rgb({:.0f},{:.0f},{:.0f})'.format(*centroid)

fig.add_trace(
    go.Scatter3d(x=[centroid[0]],
                 y=[centroid[1]],
                 z=[centroid[2]],
                 mode='markers',
                 marker=dict(size=7,
                             color=[centroid_color],
                             opacity=1))
)
fig.update_layout(scene=dict(xaxis_title='R', yaxis_title='G', zaxis_title='B'))
fig.show()
In [25]:
# Refit with three clusters, then list the distinct labels assigned to pixels.
kmeans3 = KMeans(n_clusters=3, n_init='auto', random_state=19991209)
kmeans3 = kmeans3.fit(img_flat)
np.unique(kmeans3.labels_)
Out[25]:
array([0, 1, 2])
In [26]:
# Centroid coordinates of the k=3 model: one (R, G, B) row per cluster.
centers = kmeans3.cluster_centers_
centers
Out[26]:
array([[ 41.11904835,  50.27093234,  15.9247325 ],
       [202.68983875, 173.15223957, 109.8380343 ],
       [176.32140539,  42.10443038,  27.27284161]])
In [27]:
def show_swatch(RGB_value):
    '''
    Display a single colour as a one-pixel image in a new figure.

    Args:
      RGB_value: (sequence of 3 numbers) - the R, G and B values, e.g. one row
                 of a K-means `cluster_centers_` array. Floats are rounded to
                 the nearest integer and clipped to the 0-255 range before
                 being converted to an 8-bit colour.

    Returns:
      None. The swatch is drawn on a new matplotlib figure.
    '''
    R, G, B = RGB_value
    # Round (rather than truncate) float values, and clip so out-of-range
    # inputs saturate instead of wrapping around on the uint8 cast.
    pixel = np.clip(np.rint([R, G, B]), 0, 255).astype('uint8')
    # A 1 x 1 image whose single pixel is the requested colour.
    rgb = [[pixel]]
    plt.figure()
    plt.imshow(rgb)
    plt.axis('off')
In [28]:
# Show one colour swatch per K-means centroid.
for center_rgb in centers:
    show_swatch(center_rgb)
In [31]:
def cluster_image(k, img=img, random_state=42):
    '''
    Fits a K-means model to a photograph.
    Replaces photo's pixels with RGB values of model's centroids.
    Displays the updated image.

    Args:
      k:            (int)          - Your selected K-value (number of colours)
      img:          (numpy array)  - Your original image converted to a numpy
                                     array of shape (height, width, 3)
      random_state: (int)          - Seed passed to KMeans for reproducible
                                     results; defaults to 42

    Returns:
      A tuple (output of plt.imshow(new_img), output of plt.axis('off')),
      where new_img is a new numpy array in which each pixel of the original
      image has been replaced with the coordinates of its nearest centroid.
    '''
    img_flat = img.reshape(img.shape[0] * img.shape[1], 3)
    kmeans = KMeans(n_clusters=k, n_init='auto', random_state=random_state).fit(img_flat)

    # Look up every pixel's centroid in one vectorised indexing step.
    new_img = kmeans.cluster_centers_[kmeans.labels_]
    # Centroids are float means. For an integer image, round to the nearest
    # value before casting back, because a plain cast truncates toward zero.
    if np.issubdtype(img.dtype, np.integer):
        new_img = np.rint(new_img)
    new_img = new_img.astype(img.dtype).reshape(img.shape)

    return plt.imshow(new_img), plt.axis('off')
In [32]:
cluster_image(k=3);
In [33]:
# Label array shape, the labels themselves, the distinct labels, and centroids.
for value in (kmeans3.labels_.shape,
              kmeans3.labels_,
              np.unique(kmeans3.labels_),
              kmeans3.cluster_centers_):
    print(value)
(76800,)
[1 1 1 ... 2 2 2]
[0 1 2]
[[ 41.11904835  50.27093234  15.9247325 ]
 [202.68983875 173.15223957 109.8380343 ]
 [176.32140539  42.10443038  27.27284161]]
In [34]:
# Tag each pixel row with the cluster number K-means (k=3) assigned to it.
img_flat_df = img_flat_df.assign(cluster=kmeans3.labels_)
img_flat_df.head()
Out[34]:
r g b cluster
0 211 197 38 1
1 199 181 21 1
2 178 154 0 1
3 185 152 0 1
4 184 145 0 2
In [35]:
# Helper mapping: cluster label -> Plotly colour string of that cluster's
# centroid, e.g. 0 -> 'rgb(41.119..., 50.270..., 15.924...)'.
# Components are converted to built-in floats before formatting. Under
# NumPy >= 2, str() of a tuple of NumPy scalars renders them as
# 'np.float64(...)', which is not a valid colour string.
# Built with enumerate, so it covers however many clusters the model has.
series_conversion = {
    label: 'rgb({}, {}, {})'.format(*(float(c) for c in center))
    for label, center in enumerate(kmeans3.cluster_centers_)
}
series_conversion
Out[35]:
{0: 'rgb(41.119048349015216, 50.27093233589929, 15.924732501454713)',
 1: 'rgb(202.68983875095677, 173.15223957000086, 109.83803429741378)',
 2: 'rgb(176.32140538786075, 42.10443037974727, 27.27284160986635)'}
In [36]:
# Replace the cluster numbers in the 'cluster' col with formatted RGB values
# (made ready for plotting). The labels come from the fitted model, not from
# the column itself. If the column were used, re-running this cell would push
# the already-converted strings through the dict again and turn them into NaN.
img_flat_df['cluster'] = pd.Series(kmeans3.labels_, index=img_flat_df.index).map(series_conversion)
img_flat_df.head()
Out[36]:
r g b cluster
0 211 197 38 rgb(202.68983875095677, 173.15223957000086, 10...
1 199 181 21 rgb(202.68983875095677, 173.15223957000086, 10...
2 178 154 0 rgb(202.68983875095677, 173.15223957000086, 10...
3 185 152 0 rgb(202.68983875095677, 173.15223957000086, 10...
4 184 145 0 rgb(176.32140538786075, 42.10443037974727, 27....
In [37]:
# 3-D scatter of the pixels, each drawn in the colour of its cluster's
# centroid, so pixels sharing a cluster share a colour.
trace = go.Scatter3d(x=img_flat_df.r,
                     y=img_flat_df.g,
                     z=img_flat_df.b,
                     mode='markers',
                     marker=dict(size=1,
                                 color=img_flat_df.cluster,
                                 opacity=1))

# Wrap the trace in a list, as the other 3-D plots in this notebook do.
data = [trace]

layout = go.Layout(margin=dict(l=0, r=0, b=0, t=0))

fig = go.Figure(data=data, layout=layout)
# Same axis titles as the earlier pixel plots, so the charts are comparable.
fig.update_layout(scene=dict(xaxis_title='R', yaxis_title='G', zaxis_title='B'))
fig.show()

Cluster the data for k = 2–10

In [47]:
def cluster_image_grid(k, ax, img=img, random_state=42):
    '''
    Fits a K-means model to a photograph.
    Replaces photo's pixels with RGB values of model's centroids.
    Displays the updated image on an axis of a figure.

    Args:
      k:            (int)                  - Your selected K-value
      ax:           (matplotlib Axes)      - The axis of the figure to plot to
      img:          (numpy array)          - Your original image converted to
                                             a numpy array of shape
                                             (height, width, 3)
      random_state: (int)                  - Seed passed to KMeans for
                                             reproducible results; defaults to 42

    Returns:
      None. The recoloured image, where each pixel of img has been replaced
      with the coordinates of its nearest centroid, is drawn onto `ax`.
    '''
    img_flat = img.reshape(img.shape[0] * img.shape[1], 3)
    kmeans = KMeans(n_clusters=k, n_init='auto', random_state=random_state).fit(img_flat)

    # Look up every pixel's centroid in one vectorised indexing step.
    new_img = kmeans.cluster_centers_[kmeans.labels_]
    # Centroids are float means. For an integer image, round to the nearest
    # value before casting back, because a plain cast truncates toward zero.
    if np.issubdtype(img.dtype, np.integer):
        new_img = np.rint(new_img)
    new_img = new_img.astype(img.dtype).reshape(img.shape)

    ax.imshow(new_img)
    ax.axis('off')

# Draw the recoloured image for k = 2..10 on a 3 x 3 grid of axes.
# The figure size is passed directly to subplots; the previous
# `matplotlib.pyplot.gcf()` call raised NameError because only
# `matplotlib.pyplot` (as `plt`) is imported.
fig, axs = plt.subplots(3, 3, figsize=(9, 12))
axs = axs.flatten()
k_values = np.arange(2, 11)
for ax, k in zip(axs, k_values):
    cluster_image_grid(k, ax, img=img)
    ax.set_title(f'k={k}')
In [ ]: